# -*-coding:utf-8 -*-

import base64
import datetime
import os
import random
import string
import traceback
from urllib.parse import quote, unquote_plus

from huaweicloudsdkcore.auth.credentials import BasicCredentials
from huaweicloudsdkcore.exceptions import exceptions
from huaweicloudsdkivs.v2 import (
    DetectStandardByNameAndIdRequest,
    IvsClient,
    IvsStandardByNameAndIdRequestBody,
    IvsStandardByNameAndIdRequestBodyData,
    Meta,
    StandardReqDataByNameAndId,
)
from huaweicloudsdkivs.v2.region.ivs_region import IvsRegion
from huaweicloudsdkocr.v1 import (
    DriverLicenseRequestBody,
    IdCardRequestBody,
    OcrClient,
    QualificationCertificateRequestBody,
    RecognizeDriverLicenseRequest,
    RecognizeIdCardRequest,
    RecognizeQualificationCertificateRequest,
    RecognizeTransportationLicenseRequest,
    RecognizeVehicleLicenseRequest,
    TransportationLicenseRequestBody,
    VehicleLicenseRequestBody,
)
from huaweicloudsdkocr.v1.region.ocr_region import OcrRegion
from obs import ObsClient

from ocr_decompression import FileProcessing

# Region returned by get_region() when an event record has neither an
# ``eventRegion`` nor an ``awsRegion`` field.
default_region = 'cn-north-4'

# Parent directory for the per-instance local download directories created by
# gen_local_download_path().
LOCAL_MOUNT_PATH = '/tmp/'
# Alphabet from which the random download-directory names are drawn.
LETTERS = string.ascii_letters


def check_configuration(context):
    """Validate the function configuration required by ``handler``.

    Checks, in order, the ``region``, ``obs_endpoint`` and ``result_bucket``
    user-data entries, then that an access key / secret key pair is available
    either from the function context or from the ``ak`` / ``sk`` user-data
    entries.

    Returns:
        A message describing the first missing item, or ``None`` when the
        configuration is complete.
    """
    required_user_data = (
        ('region', 'region is not configured'),
        ('obs_endpoint', 'obs_server is not configured'),
        ('result_bucket', 'result_bucket is not configured'),
    )
    for key, message in required_user_data:
        if not context.getUserData(key):
            return message

    # Guard against None so missing credentials produce the error message
    # below instead of an AttributeError from .strip().
    ak = (context.getAccessKey() or '').strip()
    sk = (context.getSecretKey() or '').strip()
    if not ak or not sk:
        # Fall back to credentials supplied explicitly as user data.
        ak = (context.getUserData('ak', '') or '').strip()
        sk = (context.getUserData('sk', '') or '').strip()
        if not ak or not sk:
            return 'ak or sk is empty'
    return None


def handler(event, context):
    """FunctionGraph entry point for OBS event notifications.

    Each record whose object key ends in ``.zip`` (case-insensitive) is passed
    to ``FileProcessing.run``; every other record goes to
    ``OpticalCharacter.process``. Each record is processed exactly once, and a
    failure on one record is logged without stopping the rest.

    Returns:
        The error message from ``check_configuration``, ``'Records is empty'``
        when the event has no ``Records`` key, or ``'complete!'`` after all
        records have been attempted.
    """
    log = context.getLogger()
    result = check_configuration(context)
    if result is not None:
        return result

    records = event.get("Records", None)
    if records is None:
        return 'Records is empty'

    file_processing = FileProcessing(context)
    try:
        optical_character = OpticalCharacter(context)
        try:
            for record in records:
                _handle_record(record, file_processing, optical_character, log)
        finally:
            optical_character.obs_client.close()
    finally:
        # Release the unzip client even if OpticalCharacter construction failed.
        file_processing.obs_client.close()
    return 'complete!'


def _handle_record(record, file_processing, optical_character, log):
    """Dispatch one event record to unzip or OCR, logging instead of raising on failure."""
    try:
        _, object_key = get_obs_obj_info(record)
        if object_key.lower().endswith(".zip"):
            file_processing.run(record)
        else:
            optical_character.process(record)
    except Exception:
        # Keep processing the remaining records; keep the traceback for diagnosis.
        log.error(traceback.format_exc())


class OpticalCharacter:
    """Recognize certificate images stored in OBS and upload the results.

    The recognition service is chosen from the object key prefix (see
    ``process``). Each recognition result is uploaded in string form to
    ``result_bucket`` under the image key with its extension replaced by
    ``.json``. For a failed OCR call, the SDK exception is uploaded instead.
    """

    def __init__(self, context=None):
        self.logger = context.getLogger()
        self.obs_client = new_obs_client(context)
        self.ocr_client = new_ocr_client(context)
        # Per-record state, filled in by analysis_record().
        self.image_bucket = None
        self.image_object_key = None
        self.image_url = None
        self.output_object_key = None
        self.result_bucket = context.getUserData('result_bucket')
        self.ivs_client = ivs_client(context)
        # Inputs for identity verification (``face`` images only).
        self.uuid = None
        self.verification_name = None
        self.verification_number = None
        self.face_picture_base64 = None
        self.download_dir = gen_local_download_path()

    def process(self, record):
        """Recognize the image referenced by one OBS event record.

        The object key prefix (case-insensitive, surrounding whitespace
        ignored) selects the service:

        * ``id`` / ``face`` -- ID card OCR; ``face`` additionally runs
          identity verification against the same image.
        * ``driver`` -- driver's license OCR.
        * ``vehicle`` -- vehicle license OCR.
        * ``transport`` -- road transport certificate OCR.
        * ``qualification`` -- road transport qualification certificate OCR.

        Returns:
            None, or an error message when the key has no supported prefix.

        Raises:
            TypeError: the ID card result has no string ``number``.
            ValueError: a ``face`` image whose ID card yielded no name.
        """
        self.analysis_record(record)
        key = self.image_object_key.strip().lower()
        if key.startswith(("id", "face")):
            self._process_id_card(verify_face=key.startswith("face"))
        elif key.startswith("driver"):
            self._recognize_and_upload(
                self.driver_license,
                'Driver_license identification is completed.')
        elif key.startswith("vehicle"):
            self._recognize_and_upload(
                self.vehicle_license,
                'Vehicle license identification is completed.')
        elif key.startswith("transport"):
            self._recognize_and_upload(
                self.road_transport_certificate,
                'Road transport certificate identification is completed.')
        elif key.startswith("qualification"):
            self._recognize_and_upload(
                self.road_transport_qualification_certificate,
                'Road transport qualification certificate identification is completed.')
        else:
            self.logger.error('Please name the picture correctly!')
            return "Please name the picture correctly!"
        return None

    def _recognize_and_upload(self, recognize, done_message):
        """Run one recognition call; on success log ``done_message`` and upload the result."""
        response = recognize()
        if response is None:
            # The recognition method has already logged and uploaded the error.
            return None
        self.logger.info(done_message)
        self.upload_result(response.result)
        return response

    def _process_id_card(self, verify_face):
        """OCR an ID card image and, when ``verify_face`` is set, verify the identity."""
        response = self._recognize_and_upload(
            self.id_identification, 'ID card identification is completed.')
        if response is None:
            return
        response_dict = response.result.to_dict()
        if not isinstance(response_dict.get("number"), str):
            raise TypeError("ID card identification failed.")
        self.verification_name = response_dict.get("name")
        self.verification_number = response_dict["number"]
        if not verify_face:
            return

        if not self.verification_name:
            raise ValueError('The ID card is not recognized. Please check the ID card picture.')
        # The same image is used as the face photo: download it and base64-encode it.
        _, file_name = os.path.split(self.image_object_key)
        download_path = self.download_dir + "/" + file_name
        if not self.download_file_from_obs(self.image_bucket, self.image_object_key,
                                           download_path):
            # download_file_from_obs has already logged the failure.
            return
        try:
            self.get_face_picture_base64(download_path)
        finally:
            # The local copy is only needed for encoding; do not let it accumulate.
            if os.path.exists(download_path):
                os.remove(download_path)
        self.uuid = datetime.datetime.now()
        response = self.verification_identity()
        if response is None:
            return
        self.logger.info('Authentication completed.')
        # NOTE(review): this uploads to the same output key as the ID card OCR
        # result above, replacing it -- confirm that is intended.
        self.upload_result(response.result)

    def id_identification(self):
        """Call ID card OCR on ``self.image_url``; on SDK error log, upload it and return None."""
        try:
            request = RecognizeIdCardRequest()
            request.body = IdCardRequestBody(
                detect_copy=True,
                detect_reproduce=True,
                return_text_location=True,
                return_verification=True,
                url=self.image_url
            )
            return self.ocr_client.recognize_id_card(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"Failed to identify ID card:"
                              f"status_code：{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            self.upload_result(e)
            return None

    def driver_license(self):
        """Call driver's license OCR on ``self.image_url``; on SDK error log, upload it and return None."""
        try:
            request = RecognizeDriverLicenseRequest()
            request.body = DriverLicenseRequestBody(
                return_text_location=True,
                return_issuing_authority=True,
                url=self.image_url
            )
            return self.ocr_client.recognize_driver_license(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"Failure to recognize driver's license："
                              f"status_code：{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            self.upload_result(e)
            return None

    def vehicle_license(self):
        """Call vehicle license OCR on ``self.image_url``; on SDK error log, upload it and return None."""
        try:
            request = RecognizeVehicleLicenseRequest()
            request.body = VehicleLicenseRequestBody(
                return_text_location=True,
                return_issuing_authority=True,
                url=self.image_url
            )
            return self.ocr_client.recognize_vehicle_license(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"Failure to identify driving license:"
                              f"status_code:{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            self.upload_result(e)
            return None

    def road_transport_certificate(self):
        """Call road transport certificate OCR; on SDK error log, upload it and return None."""
        try:
            request = RecognizeTransportationLicenseRequest()
            request.body = TransportationLicenseRequestBody(
                url=self.image_url
            )
            return self.ocr_client.recognize_transportation_license(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"Failure to identify road transport certificate:"
                              f"status_code:{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            self.upload_result(e)
            return None

    def road_transport_qualification_certificate(self):
        """Call qualification certificate OCR; on SDK error log, upload it and return None."""
        try:
            request = RecognizeQualificationCertificateRequest()
            request.body = QualificationCertificateRequestBody(
                url=self.image_url
            )
            return self.ocr_client.recognize_qualification_certificate(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"Failure to identify road transport qualification certificate:"
                              f"status_code:{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            self.upload_result(e)
            return None

    def verification_identity(self):
        """Verify name + ID number + face photo with IVS; on SDK error log and return None."""
        try:
            request = DetectStandardByNameAndIdRequest()
            req_data = [
                StandardReqDataByNameAndId(
                    verification_name=self.verification_name,
                    verification_id=self.verification_number,
                    face_image=self.face_picture_base64
                )
            ]
            request.body = IvsStandardByNameAndIdRequestBody(
                data=IvsStandardByNameAndIdRequestBodyData(req_data=req_data),
                meta=Meta(uuid=self.uuid)
            )
            return self.ivs_client.detect_standard_by_name_and_id(request)
        except exceptions.ClientRequestException as e:
            self.logger.error(f"failed to verification_identity："
                              f"status_code：{e.status_code}, "
                              f"request_id:{e.request_id}, "
                              f"error_code:{e.error_code}. "
                              f"error_msg:{e.error_msg}")
            return None

    def get_face_picture_base64(self, download_result):
        """Read the file at ``download_result`` and store it base64-encoded in ``face_picture_base64``."""
        with open(download_result, "rb") as f:
            self.face_picture_base64 = base64.b64encode(f.read()).decode('utf-8')

    def upload_result(self, result):
        """Upload ``result`` (as text) to ``result_bucket`` under ``output_object_key``.

        Failures are logged, never raised.
        """
        # Hand OBS text explicitly: SDK response models and exceptions are
        # converted to their string form.
        content = result if isinstance(result, (str, bytes)) else str(result)
        try:
            resp = self.obs_client.putContent(self.result_bucket,
                                              self.output_object_key,
                                              content=content)
            if resp.status < 300:
                self.logger.info('information upload bucket:' + self.result_bucket)
            else:
                self.logger.error("failed to upload result, "
                                  f"requestId：{resp.requestId} "
                                  f"errorCode：{resp.errorCode} "
                                  f"errorMessage：{resp.errorMessage}")
        except Exception:
            # format_exc() takes no exception argument; passing one as `limit`
            # made this handler itself raise TypeError.
            self.logger.error("failed to upload result, "
                              f"exception：{traceback.format_exc()}")

    def analysis_record(self, record):
        """Extract bucket, decoded key, OCR URL and output key from an event record."""
        region = get_region(record)
        bucket_name, object_key = get_obs_obj_info(record)
        self.image_bucket = bucket_name
        self.logger.info("input bucket_name: %s", bucket_name)
        # Keys in event records are URL-encoded (``+`` for spaces).
        self.image_object_key = unquote_plus(object_key)
        self.logger.info("input object: %s", self.image_object_key)
        self.image_url = self._build_image_url(bucket_name, region, self.image_object_key)
        self.output_object_key = self._output_key_for(self.image_object_key)

    @staticmethod
    def _build_image_url(bucket_name, region, object_key):
        """Return the OBS HTTPS URL of ``object_key``, percent-encoding the decoded key."""
        return (f"https://{bucket_name}.obs.{region}.myhuaweicloud.com/"
                f"{quote(object_key)}")

    @staticmethod
    def _output_key_for(object_key):
        """Return ``object_key`` with its extension replaced by ``.json``, keeping its folder."""
        path, filename = os.path.split(object_key)
        stem, _ = os.path.splitext(filename)
        return f"{path}/{stem}.json" if path else f"{stem}.json"

    def download_file_from_obs(self, bucket, obj_name, download_path):
        """Download ``bucket/obj_name`` to the local file ``download_path``.

        Returns:
            True on success; False when OBS returns an error status or the
            client raises (both are logged).
        """
        self.logger.info('start to download object %s from obs %s to local %s',
                         obj_name, bucket, download_path)
        try:
            resp = self.obs_client.getObject(bucket, obj_name,
                                             downloadPath=download_path)
        except Exception:
            self.logger.error(
                f"failed to download file {obj_name} from obs bucket{bucket}, "
                f"exp:{traceback.format_exc()}")
            return False
        if resp.status < 300:
            self.logger.info(
                'succeeded to download object %s from obs %s to local %s',
                obj_name, bucket, download_path)
            return True
        self.logger.error(
            f"failed to download object {obj_name} from obs {bucket}, "
            f"errorCode:{resp.errorCode} errorMessage:{resp.errorMessage}")
        return False


def get_obs_obj_info(record):
    """Return ``(bucket_name, object_key)`` taken from a record's ``s3`` section."""
    obs_section = record['s3']
    bucket_name = obs_section['bucket']['name']
    object_key = obs_section['object']['key']
    return bucket_name, object_key


def new_obs_client(context):
    """Create an OBS client for the ``obs_endpoint`` user-data endpoint.

    Credentials come from the function context. When those are empty, the
    ``ak`` / ``sk`` user-data entries are used instead, matching the fallback
    that ``check_configuration`` accepts as a valid configuration.
    """
    ak = (context.getAccessKey() or '').strip()
    sk = (context.getSecretKey() or '').strip()
    if not ak or not sk:
        ak = (context.getUserData('ak', '') or '').strip()
        sk = (context.getUserData('sk', '') or '').strip()
    return ObsClient(
        access_key_id=ak,
        secret_access_key=sk,
        server=context.getUserData('obs_endpoint')
    )


def new_ocr_client(context):
    """Create an OCR client for the ``region`` user-data entry.

    Credentials come from the function context. When those are empty, the
    ``ak`` / ``sk`` user-data entries are used instead, matching the fallback
    that ``check_configuration`` accepts as a valid configuration.
    """
    ak = (context.getAccessKey() or '').strip()
    sk = (context.getSecretKey() or '').strip()
    if not ak or not sk:
        ak = (context.getUserData('ak', '') or '').strip()
        sk = (context.getUserData('sk', '') or '').strip()
    credentials = BasicCredentials(ak, sk)
    return OcrClient.new_builder() \
        .with_credentials(credentials) \
        .with_region(OcrRegion.value_of(context.getUserData('region'))) \
        .build()


def get_region(record):
    """Return the region named by an OBS event record.

    ``eventRegion`` wins when the record has that key; otherwise
    ``awsRegion`` is read, and ``default_region`` is returned when it is
    absent too.
    """
    region_field = "eventRegion" if "eventRegion" in record else "awsRegion"
    return record.get(region_field, default_region)


def ivs_client(context):
    """Create an IVS (identity verification) client for the ``region`` user-data entry.

    Credentials come from the function context. When those are empty, the
    ``ak`` / ``sk`` user-data entries are used instead, matching the fallback
    that ``check_configuration`` accepts as a valid configuration.
    """
    ak = (context.getAccessKey() or '').strip()
    sk = (context.getSecretKey() or '').strip()
    if not ak or not sk:
        ak = (context.getUserData('ak', '') or '').strip()
        sk = (context.getUserData('sk', '') or '').strip()
    credentials = BasicCredentials(ak, sk)
    return IvsClient.new_builder() \
        .with_credentials(credentials) \
        .with_region(IvsRegion.value_of(context.getUserData("region"))) \
        .build()


def gen_local_download_path():
    """Create a new directory under LOCAL_MOUNT_PATH with a random 16-letter name.

    Returns:
        The path of the created directory.
    """
    random_name = ''.join(random.choice(LETTERS) for _ in range(16))
    target_dir = LOCAL_MOUNT_PATH + random_name
    os.makedirs(target_dir)
    return target_dir
